"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Analyse the results of the synthetic hierarchical example and plot Fig. 3.
    

"""

import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))


import glob
import numpy as np
import pandas as pd
from TemporalNetwork import ContTempNetwork
from TemporalStability import norm_mutual_information
from copy import deepcopy

import matplotlib.pyplot as plt

plt.style.use('alex_paper')

import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 1.7

from collections import Counter

# Input/output locations. The '../' paths are resolved relative to the
# current working directory (SCRIPT_DIR above is only used for sys.path).
datadir = '../paper_data/synthtempnet/hierarchic_net_paper'

# per-simulation cluster directories are named clusterdir + '_<i>'
clusterdir = os.path.join(datadir, 'clustersI')

figdir = '../figures/paper_figures'

# common part of the cluster result file names
file_prefix = 'synthtemp_hiera_paper_example'

net_file = os.path.join(datadir, 'hiera_tempnet_paper_0')

sim_param_file = os.path.join(datadir, 'sim_param.pickle')

num_sym = 10  # number of simulations (realizations) to analyse

# Deliberate stop: this script is meant to be executed cell by cell (#%%
# markers, e.g. in Spyder), not in one go.
raise Exception('stop: run this script cell by cell (#%% cells) instead of as a whole')

#%%

# Scan the cluster directory of simulation 0 and collect, over all result
# files, the distinct interval start/end indices (3rd-last and last '_'
# fields of the name) and the tau_w values (fields of the form 'w<value>').
files = os.listdir(clusterdir + '_0')

interval_starts = set()
interval_ends = set()
tau_ws = set()
for fname in files:
    fields = os.path.splitext(fname)[0].split('_')
    interval_starts.add(int(fields[-3]))
    interval_ends.add(int(fields[-1]))
    tau_ws.update(float(field[1:]) for field in fields if field.startswith('w'))
            
        
            
#%%% network file

net = ContTempNetwork.load(net_file)
sim_param = pd.read_pickle(sim_param_file)
net._compute_time_grid()

# the largest interval end index found in the file names sets the number of
# equal-length intervals between t_start and t_end
num_intervals = max(interval_ends)
int_times = np.linspace(sim_param['t_start'], sim_param['t_end'],
                        num=num_intervals + 1)


#%%
# Ground-truth partitions of the three hierarchy levels:
#   part1 - n_groups groups of n_per_group consecutive nodes,
#   part2 - unions of 3 consecutive part1 groups,
#   part3 - unions of 3 consecutive part2 groups.
part1 = [set(range(grp * sim_param['n_per_group'],
                   (grp + 1) * sim_param['n_per_group']))
         for grp in range(sim_param['n_groups'])]
part2 = [set().union(part1[k], part1[k + 1], part1[k + 2])
         for k in range(0, sim_param['n_groups'], 3)]
part3 = [set().union(part2[k], part2[k + 1], part2[k + 2])
         for k in range(0, len(part2), 3)]

# index of the first interval of the analysed results
int_start = 0

# results keyed by tau_w, then by simulation index (filled below)
res_tau = {}
#%%

load_rev = False
def analyse_res(res, typ, res_dict):
    """Extract summary values from one clustering result and append them to `res_dict[typ]`.

    Parameters
    ----------
    res : dict-like
        Loaded result pickle. The keys read are 'clust_counter_', 'best_cluster_',
        'avg_nclust_', 'nvarinf_' and 'best_stab_', each suffixed with `typ`.
        'clust_counter_<typ>' must provide a Counter-like `most_common`.
    typ : str
        Suffix selecting the result type (e.g. 'sym').
    res_dict : dict
        Maps `typ` to a dict of lists; exactly one value is appended to each
        list ('avg_nclust', 'nvarinf', 'mostc_clust', 'max_nmis', 'nmi_1',
        'nmi_2', 'nmi_3', 'nmi_eq1', 'best_cluster', 'best_stab', 'bestc_nclust').
    """
    # most common partition, converted to a list of node sets
    mostc_clust = [set(c) for c in res['clust_counter_' + typ].most_common(1)[0][0]]
    best_clust = res['best_cluster_' + typ]

    # NMI of the best partition with each ground-truth hierarchy level
    nmis = np.array([norm_mutual_information(best_clust, part1),
                     norm_mutual_information(best_clust, part2),
                     norm_mutual_information(best_clust, part3)])
    max_nmis = np.argmax(nmis)

    # isclose instead of == 1: the NMI of identical partitions can differ from
    # 1 by rounding (the aggregation step uses np.allclose for the same test)
    nmi_eq1 = np.argwhere(np.isclose(nmis, 1.0))

    out = res_dict[typ]
    out['avg_nclust'].append(res['avg_nclust_' + typ])
    out['nvarinf'].append(res['nvarinf_' + typ])
    out['mostc_clust'].append(mostc_clust)
    out['max_nmis'].append(max_nmis)
    out['nmi_1'].append(nmis[0])
    out['nmi_2'].append(nmis[1])
    out['nmi_3'].append(nmis[2])
    out['nmi_eq1'].append(nmi_eq1)
    out['best_cluster'].append(best_clust)
    out['best_stab'].append(res['best_stab_' + typ])
    out['bestc_nclust'].append(len(best_clust))
    
import fnmatch  # file-name matching (used directly, not through glob's namespace)

for tau_w in tau_ws:
    print(tau_w)

    res_tau[tau_w] = {}

    for i in range(num_sym):
        print(i)

        filedir = clusterdir + f'_{i}'
        files = os.listdir(filedir)

        # Result files for this simulation, tau_w and starting interval, in an
        # exact ('PT') and a linear-approximation ('PT_lin') version. The
        # patterns do not depend on the file, so build them once.
        name_prefix = f'clusters_{file_prefix}_{i}_'
        pattern = name_prefix + 'tau_w{0:.3e}_PT_{1:06d}_to*'.format(tau_w, int_start)
        pattern_lin = name_prefix + 'tau_w{0:.3e}_PT_lin_{1:06d}_to*'.format(tau_w, int_start)

        filt_files = []
        filt_files_lin = []
        for file in files:
            if fnmatch.fnmatch(file, pattern):
                filt_files.append(file)
            elif fnmatch.fnmatch(file, pattern_lin):
                filt_files_lin.append(file)

        filt_files = sorted(filt_files)
        filt_files_lin = sorted(filt_files_lin)

        # end time t2 of each analysed interval (from the exact-result files)
        t2s = []

        # template of the per-type result lists filled by analyse_res
        res_dict = {'avg_nclust': [],
                    'nvarinf': [],
                    'mostc_clust': [],
                    'nmi_1': [],
                    'nmi_2': [],
                    'nmi_3': [],
                    'max_nmis': [],
                    'nmi_eq1': [],
                    'best_cluster': [],
                    'best_stab': [],
                    'bestc_nclust': []}

        clust_res = {'sym': deepcopy(res_dict)}
        for file in filt_files:
            # the last '_' field of the file name is the interval end index
            int_end = int(os.path.splitext(file)[0].split('_')[-1])
            t2s.append(int_times[int_end])

            res = pd.read_pickle(os.path.join(filedir, file))
            for typ in clust_res:
                analyse_res(res, typ, clust_res)

        clust_res_lin = {'sym': deepcopy(res_dict)}
        for file in filt_files_lin:
            res = pd.read_pickle(os.path.join(filedir, file))
            for typ in clust_res_lin:
                analyse_res(res, typ, clust_res_lin)

        res_tau[tau_w][i] = {'clust_res': clust_res,
                             'clust_res_lin': clust_res_lin,
                             't2s': t2s}


#%% aggregate results

def _level_nmis(clust):
    """Return the NMIs of partition `clust` with the ground-truth levels part1, part2, part3."""
    return np.array([norm_mutual_information(clust, part1),
                     norm_mutual_information(clust, part2),
                     norm_mutual_information(clust, part3)])


def _aggregate_over_sims(best_clusts_sims, nvarinfs_sims):
    """Aggregate the results of several simulations, time point by time point.

    Parameters
    ----------
    best_clusts_sims : list
        One entry per simulation, each a list (ordered by time point) of best
        partitions given as iterables of node collections.
    nvarinfs_sims : list
        One entry per simulation, each a list (ordered by time point) of
        normalized variation of information values.

    Returns
    -------
    dict
        Lists with one entry per time point under the keys 'mostc_bestcs'
        (most common best partition), 'n_mcs' (how many simulations found it),
        'nclust_mbcs' (its number of clusters), 'mostc_bestc_nmis' (indices of
        the ground-truth levels where its NMI is 1), 'mean_nvarinfs' /
        'std_nvarinfs' (over all simulations) and 'mean_nvarinfs_mostc' /
        'std_nvarinfs_mostc' (over the simulations that found it).
    """
    out = {key: [] for key in ('mostc_bestcs', 'n_mcs', 'nclust_mbcs',
                               'mostc_bestc_nmis', 'mean_nvarinfs',
                               'mean_nvarinfs_mostc', 'std_nvarinfs',
                               'std_nvarinfs_mostc')}
    for t in range(len(best_clusts_sims[0])):
        best_c_t = [sim_clusts[t] for sim_clusts in best_clusts_sims]
        nvarinf_t = [sim_nvi[t] for sim_nvi in nvarinfs_sims]

        # canonical (sorted, hashable) form so identical partitions are counted together
        counter = Counter(tuple(sorted(tuple(sorted(comm)) for comm in clust))
                          for clust in best_c_t)
        mbc_key, n_mc = counter.most_common(1)[0]
        mbc = [set(comm) for comm in mbc_key]

        # isclose/allclose rather than ==: NMI of identical partitions may be off by rounding
        out['mostc_bestc_nmis'].append(np.argwhere(np.isclose(_level_nmis(mbc), 1.0)))

        # simulations whose best partition is the most common one
        mostc_nvarinfs = [nvarinf_t[k] for k, clust in enumerate(best_c_t)
                          if np.allclose(norm_mutual_information(clust, mbc), 1.0)]

        out['mostc_bestcs'].append(mbc)
        out['n_mcs'].append(n_mc)
        out['nclust_mbcs'].append(len(mbc))
        out['mean_nvarinfs'].append(np.mean(nvarinf_t))
        out['mean_nvarinfs_mostc'].append(np.mean(mostc_nvarinfs))
        out['std_nvarinfs'].append(np.std(nvarinf_t))
        out['std_nvarinfs_mostc'].append(np.std(mostc_nvarinfs))
    return out


aggr_res = {}
for tau_w in res_tau.keys():
    sims = res_tau[tau_w]
    # one list per simulation, each ordered by time point
    best_clusts = [sims[i]['clust_res']['sym']['best_cluster'] for i in sims]
    best_clusts_lin = [sims[i]['clust_res_lin']['sym']['best_cluster'] for i in sims]
    nvarinfs = [sims[i]['clust_res']['sym']['nvarinf'] for i in sims]
    nvarinfs_lin = [sims[i]['clust_res_lin']['sym']['nvarinf'] for i in sims]

    # exact and linear-approximation results are aggregated independently
    # (each with its own data and its own number of time points)
    aggr_sym = _aggregate_over_sims(best_clusts, nvarinfs)
    aggr_lin = _aggregate_over_sims(best_clusts_lin, nvarinfs_lin)

    aggr_res[tau_w] = {'best_clusts': best_clusts,
                       'best_clusts_lin': best_clusts_lin,
                       'nvarinfs': nvarinfs,
                       'nvarinfs_lin': nvarinfs_lin,
                       **aggr_sym,
                       **{key + '_lin': val for key, val in aggr_lin.items()}}
#%% block probabilities
inter_group_probs = sim_param['inter_group_probs']

# Each entry (i, j) of inter_group_probs fills the block of the node pairs of
# groups i and j; groups are n_per_group consecutive nodes (as in part1).
n_per_group = sim_param['n_per_group']

block_probs = np.ones((net.num_nodes, net.num_nodes))
for i in range(inter_group_probs.shape[0]):
    for j in range(inter_group_probs.shape[1]):
        block_probs[i * n_per_group:(i + 1) * n_per_group,
                    j * n_per_group:(j + 1) * n_per_group] = inter_group_probs[i, j]

# no self-interactions, then normalize every row to sum to 1
np.fill_diagonal(block_probs, 0)
block_probs = block_probs / block_probs.sum(1)[:, np.newaxis]

# quick-look plot; np.ma.log10 masks the zero diagonal instead of producing -inf
plt.matshow(np.ma.log10(block_probs))


#%% plots aggregated results  fig for paper

mpl.rcParams['lines.linewidth'] = 1.1
mpl.rcParams['font.size'] = 12

from matplotlib.gridspec import GridSpec
from matplotlib.colors import LogNorm
import matplotlib as mpl
from palettable import colorbrewer

lw = 1.1
markersize = 4

legend_fontsize = 10


fig = plt.figure(figsize=(14, 5))
gs = GridSpec(2, 8, figure=fig)

ax0 = fig.add_subplot(gs[:, 0:2])              # A: interaction probabilities
ax1 = fig.add_subplot(gs[:, 2:5])              # B: number of clusters
ax2 = fig.add_subplot(gs[0, 5:8], sharex=ax1)  # C: norm. var. inf.
ax3 = fig.add_subplot(gs[1, 5:8], sharex=ax1)  # D: level with NMI = 1

seqcmap = mpl.colors.LinearSegmentedColormap.from_list(
    'seq', colorbrewer.diverging.PuOr_11_r.mpl_colors[6:])

# log color scale restricted to the non-zero probabilities
m = ax0.matshow(block_probs, cmap=seqcmap,
                norm=LogNorm(vmin=block_probs[block_probs > 0].min(),
                             vmax=block_probs[block_probs > 0].max()))

ax0.set_xticks(np.arange(0, net.num_nodes, 20))
ax0.set_yticks(np.arange(0, net.num_nodes, 20))

plt.colorbar(m, ax=ax0, label='interaction probability',
             fraction=0.046, pad=0.04)

ax2.ticklabel_format(axis='y', style='sci', scilimits=(-2, 2))

# dashed reference lines at the number of clusters of the three hierarchy levels
for n_level in (27, 9, 3):
    ax1.plot((0, sim_param['t_end']), (n_level, n_level), '--k', lw=lw)


taus = [0.01, 75, 1000, 2000]  # tau_w values shown in the figure

colors = ['#D13204', "#613D94", "#D0730F", "#359455"]

# vertical offsets so that the markers of different tau_w do not overlap in panel D
y_shift = [-0.1, -0.05, 0.05, 0.1]

for i, tau_w in enumerate(taus):
    aggr = aggr_res[tau_w]

    # end times t2 of the analysed intervals for this tau_w, taken from its
    # first simulation (assumes all simulations share the same intervals)
    t2s_tau = np.array(next(iter(res_tau[tau_w].values()))['t2s'])

    mask_nmc = np.array(aggr['n_mcs']) > 0
    mask_nmc_lin = np.array(aggr['n_mcs_lin']) > 0
    # curves start at t=0 with every node in its own cluster
    ax1.plot(np.hstack(([0], t2s_tau[mask_nmc])),
             np.hstack(([net.num_nodes], np.array(aggr['nclust_mbcs'])[mask_nmc])),
             '--', label=r'$\tau_w$ = {:.4g}'.format(tau_w),
             lw=lw, markersize=markersize,
             color=colors[i])
    ax1.plot(np.hstack(([0], t2s_tau[mask_nmc_lin])),
             np.hstack(([net.num_nodes], np.array(aggr['nclust_mbcs_lin'])[mask_nmc_lin])),
             'o', label=r'lin. approx. $\tau_w$ = {:.4g}'.format(tau_w),
             lw=lw, markersize=markersize,
             color=colors[i])

    # where a single simulation found the most common partition, show the mean
    # over all simulations instead of that single value
    mask_single = np.array(aggr['n_mcs']) == 1
    varinf = np.array(aggr['mean_nvarinfs_mostc'])
    varinf[mask_single] = np.array(aggr['mean_nvarinfs'])[mask_single]

    # padded with 0 at t=0 and just after t_end
    t_pad = np.hstack(([0], t2s_tau, [sim_param['t_end'] + 1]))
    varinf_pad = np.hstack(([0], varinf, [0]))
    # NOTE(review): the band uses the std over all simulations even where the
    # line shows the mean over the most-common-partition simulations; confirm intended.
    varinf_std_pad = np.hstack(([0], np.array(aggr['std_nvarinfs']), [0]))

    ax2.plot(t_pad, varinf_pad,
             '-', label=r'$\tau_w$ = {:.4g}'.format(tau_w),
             color=colors[i])
    ax2.fill_between(t_pad,
                     varinf_pad - varinf_std_pad,
                     varinf_pad + varinf_std_pad,
                     color=colors[i],
                     alpha=0.1)

    # first (lowest) hierarchy level, numbered from 1, at which the most common
    # best partition has NMI = 1; NaN if there is none
    ax3.plot(t2s_tau,
             [x[0][0] + 1 + y_shift[i] if len(x) > 0 else np.nan
              for x in aggr['mostc_bestc_nmis']],
             'o', color=colors[i],
             label=r'$\tau_w$ = {:.3g}'.format(tau_w),
             markersize=markersize)


for ax in (ax1, ax2, ax3):
    ax.set_xlim([0, 900])

ax2.set_ylim([0, ax2.get_ylim()[1]])
ax1.set_ylim([0, ax1.get_ylim()[1]])

ax1.set_ylabel('Num. of clusters')
ax2.set_ylabel('Norm. Var. Inf.')

ax1.text(180, 11, '2nd level', fontsize=legend_fontsize)

ax1.text(180, 4.5, '3rd level', fontsize=legend_fontsize)

ax1.text(180, 29, '1st level', fontsize=legend_fontsize)


ax3.set_xlabel(r'$t_2 - t_1$ [a.u.]')
ax1.set_xlabel(r'$t_2 - t_1$ [a.u.]')

ax3.set_ylabel('level with NMI=1')


ax1.legend(fontsize=legend_fontsize)
ax2.legend(fontsize=legend_fontsize)

# panel letters; the matplotlib Text property is `usetex` (lowercase)
ax0.text(-0.2, 1.1, r'\textbf{A}', transform=ax0.transAxes, usetex=True)
ax1.text(-0.1, 1.05, r'\textbf{B}', transform=ax1.transAxes, usetex=True)
ax2.text(-0.1, 1.05, r'\textbf{C}', transform=ax2.transAxes, usetex=True)
ax3.text(-0.1, 1.05, r'\textbf{D}', transform=ax3.transAxes, usetex=True)

#%%

# write the figure both as vector PDF and as a 600 dpi PNG
for fig_name, save_kwargs in (('fig_hiera_clust.pdf', {}),
                              ('fig_hiera_clust.png', {'dpi': 600})):
    plt.savefig(os.path.join(figdir, fig_name), **save_kwargs)

